Linear Discriminant Analysis (LDA) is a powerful technique for dimensionality reduction and classification. It works by finding linear combinations of features that best separate different classes in the data. Unlike Principal Component Analysis (PCA), which focuses on maximizing variance, LDA specifically aims to maximize the separability between classes.
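Before diving in, here is a minimal sketch of that objective on a toy two-class problem (the synthetic data and the closed-form solution below are illustrative assumptions, not part of the analysis that follows): LDA looks for a direction w that maximizes between-class scatter relative to within-class scatter, and for two classes that direction has the closed form Sw⁻¹(m1 − m2).

```r
# Toy two-class problem: find the Fisher discriminant direction
set.seed(1)
class_a <- matrix(rnorm(40, mean = 0), ncol = 2)
class_b <- matrix(rnorm(40, mean = 3), ncol = 2)
mean_a <- colMeans(class_a)
mean_b <- colMeans(class_b)
# Within-class scatter: pooled (unnormalized) covariance of the two classes
Sw <- cov(class_a) * (nrow(class_a) - 1) + cov(class_b) * (nrow(class_b) - 1)
# Two-class closed form: w is proportional to Sw^-1 (m1 - m2)
w <- solve(Sw, mean_a - mean_b)
w <- w / sqrt(sum(w^2))  # normalize to unit length
w
```

Projecting both classes onto `w` gives the one-dimensional representation in which they are best separated in Fisher's sense.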
Let’s start with a simple example to understand how LDA works on well-separated classes.
# Function to generate sample data for demonstration
generate_sample_data <- function(n_samples = 300) {
# Generate three classes with different means and same covariance
class1 <- MASS::mvrnorm(n = n_samples/3,
mu = c(50, 60, 70),
Sigma = matrix(c(100, 20, 15, 20, 100, 25, 15, 25, 100), 3, 3))
class2 <- MASS::mvrnorm(n = n_samples/3,
mu = c(80, 40, 50),
Sigma = matrix(c(100, 20, 15, 20, 100, 25, 15, 25, 100), 3, 3))
class3 <- MASS::mvrnorm(n = n_samples/3,
mu = c(30, 90, 80),
Sigma = matrix(c(100, 20, 15, 20, 100, 25, 15, 25, 100), 3, 3))
# Combine data
sample_data <- rbind(
data.frame(feature1 = class1[,1], feature2 = class1[,2], feature3 = class1[,3], class = "Class A"),
data.frame(feature1 = class2[,1], feature2 = class2[,2], feature3 = class2[,3], class = "Class B"),
data.frame(feature1 = class3[,1], feature2 = class3[,2], feature3 = class3[,3], class = "Class C")
)
return(sample_data)
}
# Generate sample data
sample_data <- generate_sample_data()
Let’s examine the structure of our sample data:
head(sample_data)
## feature1 feature2 feature3 class
## 1 52.77124 50.77344 88.78080 Class A
## 2 56.13956 52.65127 77.00099 Class A
## 3 37.88582 50.65953 59.17225 Class A
## 4 48.08763 56.45356 74.08270 Class A
## 5 41.14525 63.27559 71.86210 Class A
## 6 37.98836 50.58005 55.97900 Class A
# Summary statistics by class
sample_data %>%
group_by(class) %>%
summarise(
n = n(),
mean_feature1 = mean(feature1),
mean_feature2 = mean(feature2),
mean_feature3 = mean(feature3),
sd_feature1 = sd(feature1),
sd_feature2 = sd(feature2),
sd_feature3 = sd(feature3)
) %>%
kable(digits = 2)
| class | n | mean_feature1 | mean_feature2 | mean_feature3 | sd_feature1 | sd_feature2 | sd_feature3 |
|---|---|---|---|---|---|---|---|
| Class A | 100 | 48.88 | 58.71 | 70.55 | 9.87 | 8.68 | 9.55 |
| Class B | 100 | 80.93 | 40.38 | 49.49 | 10.56 | 9.70 | 9.73 |
| Class C | 100 | 31.93 | 90.32 | 80.97 | 10.37 | 9.71 | 10.77 |
Now let’s visualize the original 3D data:
# Create 3D scatter plot
p1 <- ggplot(sample_data, aes(x = feature1, y = feature2, color = class)) +
geom_point(size = 2, alpha = 0.7) +
labs(title = "Sample Data: Original Features 1 vs 2",
x = "Feature 1", y = "Feature 2") +
theme_minimal() +
theme(legend.position = "bottom",
panel.background = element_rect(fill = "white"))
p2 <- ggplot(sample_data, aes(x = feature1, y = feature3, color = class)) +
geom_point(size = 2, alpha = 0.7) +
labs(title = "Sample Data: Original Features 1 vs 3",
x = "Feature 1", y = "Feature 3") +
theme_minimal() +
theme(legend.position = "bottom",
panel.background = element_rect(fill = "white"))
p3 <- ggplot(sample_data, aes(x = feature2, y = feature3, color = class)) +
geom_point(size = 2, alpha = 0.7) +
labs(title = "Sample Data: Original Features 2 vs 3",
x = "Feature 2", y = "Feature 3") +
theme_minimal() +
theme(legend.position = "bottom",
panel.background = element_rect(fill = "white"))
# Arrange plots
grid.arrange(p1, p2, p3, ncol = 2)
Now let’s apply LDA to see how it transforms our data:
# Prepare data for LDA
X_sample <- as.matrix(sample_data[, 1:3])
y_sample <- sample_data$class
# Perform LDA
lda_sample <- lda(X_sample, y_sample)
# Transform data
X_lda_sample <- predict(lda_sample, X_sample)$x
# Display LDA results
cat("Number of classes:", length(lda_sample$lev), "\n")
cat("\nPrior probabilities:\n")
print(lda_sample$prior)
cat("\nLDA coefficients:\n")
print(lda_sample$scaling)
## Number of classes: 3
## Prior probabilities:
## Class A Class B Class C
## 0.3333333 0.3333333 0.3333333
##
## LDA coefficients:
## LD1 LD2
## feature1 -0.07406551 -0.04545376
## feature2 0.06960232 -0.08690477
## feature3 0.03710453 0.06520045
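These coefficients are all LDA needs to project points into the discriminant space. A hedged, self-contained sketch (on freshly generated synthetic data, since it must run on its own) showing what `predict()` does with them: it centres the features at the prior-weighted mean of the class means, then multiplies by the scaling matrix.

```r
library(MASS)
set.seed(42)
# Two synthetic classes in three dimensions
X <- rbind(matrix(rnorm(60, mean = 0), ncol = 3),
           matrix(rnorm(60, mean = 2), ncol = 3))
g <- rep(c("A", "B"), each = 20)
fit <- lda(X, g)
# predict() centres at the prior-weighted average of the class means...
ctr <- colSums(fit$prior * fit$means)
# ...then multiplies by the scaling (coefficient) matrix
manual_scores <- scale(X, center = ctr, scale = FALSE) %*% fit$scaling
stopifnot(isTRUE(all.equal(unname(manual_scores),
                           unname(predict(fit, X)$x))))
```

So the LD scores are nothing more mysterious than a centred linear combination of the original features.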
Let’s visualize the LDA transformation:
# Create visualization of LDA projection
lda_plot_data <- data.frame(
LD1 = X_lda_sample[,1],
LD2 = X_lda_sample[,2],
class = y_sample
)
ggplot(lda_plot_data, aes(x = LD1, y = LD2, color = class)) +
geom_point(size = 2, alpha = 0.7) +
stat_ellipse(level = 0.95) +
labs(title = "Sample Data: LDA Projection",
subtitle = "Notice how well the classes are separated after LDA transformation",
x = "Linear Discriminant 1",
y = "Linear Discriminant 2") +
theme_minimal() +
theme(legend.position = "bottom",
panel.background = element_rect(fill = "white"))
Now let’s apply LDA to a real-world example: the Pokemon dataset. We’ll predict Pokemon types from their base stats.
# Load Pokemon dataset
pokemon <- read.csv("Pokemon-Dataset/pokemon.csv")
# Display basic information
cat("Dataset dimensions:", nrow(pokemon), "rows ×", ncol(pokemon), "columns\n")
cat("Columns:", paste(names(pokemon), collapse = ", "), "\n")
## Dataset dimensions: 801 rows × 41 columns
## Columns: abilities, against_bug, against_dark, against_dragon, against_electric, against_fairy, against_fight, against_fire, against_flying, against_ghost, against_grass, against_ground, against_ice, against_normal, against_poison, against_psychic, against_rock, against_steel, against_water, attack, base_egg_steps, base_happiness, base_total, capture_rate, classfication, defense, experience_growth, height_m, hp, japanese_name, name, percentage_male, pokedex_number, sp_attack, sp_defense, speed, type1, type2, weight_kg, generation, is_legendary
# Display first few rows
head(pokemon[, c("name", "type1", "type2", "hp", "attack", "defense", "sp_attack", "sp_defense", "speed")])
## name type1 type2 hp attack defense sp_attack sp_defense speed
## 1 Bulbasaur grass poison 45 49 49 65 65 45
## 2 Ivysaur grass poison 60 62 63 80 80 60
## 3 Venusaur grass poison 80 100 123 122 120 80
## 4 Charmander fire 39 52 43 60 50 65
## 5 Charmeleon fire 58 64 58 80 65 80
## 6 Charizard fire flying 78 104 78 159 115 100
# Select only Pokemon with single type (no dual types)
single_type_pokemon <- pokemon[is.na(pokemon$type2) | pokemon$type2 == "", ]
cat("Single-type Pokemon count:", nrow(single_type_pokemon), "\n")
## Single-type Pokemon count: 384
# Select relevant features for analysis
features <- c("hp", "attack", "defense", "sp_attack", "sp_defense", "speed", "type1")
pokemon_subset <- single_type_pokemon[, features]
# Remove rows with missing values
pokemon_subset <- pokemon_subset[complete.cases(pokemon_subset), ]
cat("Final dataset size:", nrow(pokemon_subset), "rows\n")
## Final dataset size: 384 rows
# Count Pokemon of each type
type_counts <- table(pokemon_subset$type1)
cat("\nType distribution:\n")
print(type_counts)
## Type distribution:
##
## bug dark dragon electric fairy fighting fire flying
## 18 9 12 26 16 22 27 1
## ghost grass ground ice normal poison psychic rock
## 9 37 10 12 61 13 35 11
## steel water
## 4 61
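One thing the counts above make obvious: some classes are tiny (flying has 1 Pokemon, steel has 4), which gives LDA almost nothing to estimate a class mean from. A common remedy, shown here as a hedged sketch on a toy data frame (the cutoff `min_n = 10` is an arbitrary illustrative choice, not part of the analysis below), is to drop rare classes before fitting:

```r
# Toy sketch of filtering out rare classes before fitting a classifier
toy <- data.frame(type1 = c(rep("grass", 12), rep("flying", 1), rep("water", 15)))
min_n <- 10  # arbitrary cutoff for illustration
keep <- names(which(table(toy$type1) >= min_n))
toy_kept <- toy[toy$type1 %in% keep, , drop = FALSE]
table(factor(toy_kept$type1))
```

The analysis below keeps all 18 types, so bear the tiny classes in mind when reading the per-class results.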
Let’s visualize the type distribution:
# Create type distribution plot
type_df <- data.frame(
type = names(type_counts),
count = as.numeric(type_counts)
)
ggplot(type_df, aes(x = reorder(type, count), y = count)) +
geom_bar(stat = "identity", fill = "steelblue", alpha = 0.8) +
coord_flip() +
labs(title = "Pokemon Type Distribution",
subtitle = "Number of Pokemon per type (single-type only)",
x = "Pokemon Type",
y = "Count") +
theme_minimal() +
theme(axis.text.y = element_text(size = 10),
panel.background = element_rect(fill = "white"))
Let’s examine the distribution of Pokemon stats:
# Create histograms for each stat
pokemon_numeric <- pokemon_subset[, 1:6] # Select only numeric columns
pokemon_long <- data.frame(
stat = rep(names(pokemon_numeric), each = nrow(pokemon_numeric)),
value = as.vector(as.matrix(pokemon_numeric))
)
ggplot(pokemon_long, aes(x = value, fill = stat)) +
geom_histogram(bins = 30, alpha = 0.7) +
facet_wrap(~stat, scales = "free", ncol = 2) +
labs(title = "Distribution of Pokemon Stats",
x = "Stat Value", y = "Frequency") +
theme_minimal() +
theme(legend.position = "none",
panel.background = element_rect(fill = "white"))# Calculate correlation matrix
numeric_features <- pokemon_subset[, 1:6]
cor_matrix <- cor(numeric_features)
# Create correlation plot
corrplot(cor_matrix, method = "color", type = "upper",
addCoef.col = "black", tl.col = "black", tl.srt = 45,
title = "Pokemon Stats Correlation Matrix",
mar = c(0,0,2,0))
# Prepare data for LDA
X <- as.matrix(pokemon_subset[, 1:6]) # Features
y <- pokemon_subset$type1 # Target variable
# Normalize data (recommended for LDA)
X_scaled <- scale(X)
# Perform LDA
lda_model <- lda(X_scaled, y)
cat("LDA model fitted successfully\n")
cat("Number of classes:", length(lda_model$lev), "\n")
cat("\nPrior probabilities:\n")
print(lda_model$prior)
## LDA model fitted successfully
## Number of classes: 18
## Prior probabilities:
## bug dark dragon electric fairy fighting
## 0.046875000 0.023437500 0.031250000 0.067708333 0.041666667 0.057291667
## fire flying ghost grass ground ice
## 0.070312500 0.002604167 0.023437500 0.096354167 0.026041667 0.031250000
## normal poison psychic rock steel water
## 0.158854167 0.033854167 0.091145833 0.028645833 0.010416667 0.158854167
# Extract coefficients
coef_matrix <- lda_model$scaling
feature_names <- c("HP", "Attack", "Defense", "Sp. Attack", "Sp. Defense", "Speed")
# Create coefficient heatmap
coef_df <- data.frame(
feature = rep(feature_names, ncol(coef_matrix)),
discriminant = rep(paste0("LD", 1:ncol(coef_matrix)), each = length(feature_names)),
coefficient = as.vector(coef_matrix)
)
ggplot(coef_df, aes(x = discriminant, y = feature, fill = coefficient)) +
geom_tile() +
scale_fill_gradient2(low = "blue", mid = "white", high = "red",
midpoint = 0, name = "Coefficient") +
labs(title = "LDA Coefficients Heatmap",
subtitle = "Shows how each feature contributes to each discriminant",
x = "Linear Discriminants",
y = "Features") +
theme_minimal() +
theme(axis.text.x = element_text(angle = 45, hjust = 1),
panel.background = element_rect(fill = "white"))# Calculate feature importance (sum of absolute coefficients across discriminants)
feature_importance <- rowSums(abs(coef_matrix))
importance_df <- data.frame(
feature = feature_names,
importance = feature_importance
)
ggplot(importance_df, aes(x = reorder(feature, importance), y = importance)) +
geom_bar(stat = "identity", fill = "steelblue", alpha = 0.8) +
coord_flip() +
labs(title = "Feature Importance in LDA",
subtitle = "Sum of absolute coefficients across all discriminants",
x = "Features",
y = "Importance Score") +
theme_minimal() +
theme(axis.text.y = element_text(size = 10),
panel.background = element_rect(fill = "white"))# Create data frame for plotting
plot_data <- data.frame(
LD1 = X_lda[,1],
LD2 = X_lda[,2],
LD3 = X_lda[,3],
type = y
)
# Plot first two discriminants
p1 <- ggplot(plot_data, aes(x = LD1, y = LD2, color = type)) +
geom_point(alpha = 0.6, size = 2) +
stat_ellipse(level = 0.95, alpha = 0.3) +
labs(title = "Pokemon Types: LDA Projection (LD1 vs LD2)",
x = "Linear Discriminant 1",
y = "Linear Discriminant 2") +
theme_minimal() +
theme(legend.position = "bottom",
legend.text = element_text(size = 8),
panel.background = element_rect(fill = "white"))
# Plot first vs third discriminant
p2 <- ggplot(plot_data, aes(x = LD1, y = LD3, color = type)) +
geom_point(alpha = 0.6, size = 2) +
stat_ellipse(level = 0.95, alpha = 0.3) +
labs(title = "Pokemon Types: LDA Projection (LD1 vs LD3)",
x = "Linear Discriminant 1",
y = "Linear Discriminant 3") +
theme_minimal() +
theme(legend.position = "bottom",
legend.text = element_text(size = 8),
panel.background = element_rect(fill = "white"))
# Arrange plots
grid.arrange(p1, p2, ncol = 2)
# Make predictions
predictions <- predict(lda_model, X_scaled)
# Calculate accuracy
accuracy <- mean(predictions$class == y)
cat("Overall accuracy:", round(accuracy * 100, 2), "%\n")
## Overall accuracy: 33.59 %
# Create confusion matrix
conf_matrix <- table(Actual = y, Predicted = predictions$class)
# Calculate per-class accuracy
per_class_accuracy <- diag(conf_matrix) / rowSums(conf_matrix)
cat("\nPer-class accuracy:\n")
print(round(per_class_accuracy * 100, 2))
##
## Per-class accuracy:
## bug dark dragon electric fairy fighting fire flying
## 44.44 0.00 0.00 23.08 25.00 40.91 7.41 0.00
## ghost grass ground ice normal poison psychic rock
## 33.33 0.00 20.00 0.00 57.38 0.00 42.86 45.45
## steel water
## 25.00 63.93
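One caveat on the ~33.6% figure: it is training accuracy, scored on the same Pokemon the model was fitted to, so it flatters the model. `MASS::lda` supports leave-one-out cross-validation via `CV = TRUE`; a self-contained sketch on the built-in `iris` data (the same call applies to `X_scaled` and `y` above):

```r
library(MASS)
# CV = TRUE returns leave-one-out predictions instead of a reusable model
cv_fit <- lda(scale(iris[, 1:4]), iris$Species, CV = TRUE)
cv_acc <- mean(cv_fit$class == iris$Species)
cat("Leave-one-out CV accuracy:", round(cv_acc * 100, 1), "%\n")
```

Note that with `CV = TRUE` the returned object holds only `$class` and `$posterior` for the held-out predictions, so you cannot call `predict()` on it afterwards.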
Let’s visualize the confusion matrix:
# Create confusion matrix heatmap
conf_df <- as.data.frame(conf_matrix)
conf_df$Actual <- factor(conf_df$Actual, levels = unique(conf_df$Actual))
conf_df$Predicted <- factor(conf_df$Predicted, levels = unique(conf_df$Predicted))
ggplot(conf_df, aes(x = Predicted, y = Actual, fill = Freq)) +
geom_tile() +
scale_fill_gradient(low = "white", high = "red", name = "Count") +
geom_text(aes(label = Freq), color = "black", size = 3) +
labs(title = "Confusion Matrix Heatmap",
subtitle = paste("Overall Accuracy:", round(accuracy * 100, 2), "%"),
x = "Predicted Type",
y = "Actual Type") +
theme_minimal() +
theme(axis.text.x = element_text(angle = 45, hjust = 1),
axis.text.y = element_text(size = 8),
panel.background = element_rect(fill = "white"))Sample Data: LDA successfully separated the three synthetic classes, demonstrating its effectiveness on well-separated data.
Pokemon Classification: The overall accuracy of ~33.6% suggests that Pokemon types are not easily separable using only their base stats.
Feature Importance: Some stats (like Speed and Attack) appear more important for type classification than others.
Class Separability: The overlapping ellipses in the LDA projections indicate that many Pokemon types have similar stat distributions.
Linear Discriminant Analysis provides valuable insights into the Pokemon dataset, revealing both the potential and limitations of linear classification approaches. While the overall accuracy suggests that Pokemon types cannot be perfectly predicted from base stats alone, LDA successfully identifies the most discriminative features and provides a foundation for more sophisticated analysis.
The technique demonstrates its strength in dimensionality reduction and class separability analysis, making it a valuable tool for exploratory data analysis and preprocessing in machine learning pipelines.